(ns thi.ng.geom.ptf
  #?(:cljs (:require-macros [thi.ng.math.macros :as mm]))
  (:require
   [thi.ng.geom.core :as g]
   [thi.ng.geom.vector :as v :refer [vec2 vec3 V3]]
   [thi.ng.geom.matrix :refer [matrix44 M44]]
   [thi.ng.geom.attribs :as attr]
   [thi.ng.geom.basicmesh :as bm]
   [thi.ng.math.core :as m :refer [*eps* TWO_PI]]
   #?(:clj [thi.ng.math.macros :as mm])))

;; Parallel Transport Frames
;; 
;; http://en.wikipedia.org/wiki/Parallel_transport
;; http://www.cs.indiana.edu/pub/techreports/TR425.pdf

(defn compute-tangents
  "Takes a seq of path points and computes their tangents (pairwise
  normalized direction vectors). The direction of the final segment
  is repeated for the last point, so the result holds exactly one
  tangent per path point. Returns an empty vector if fewer than 2
  points are given."
  [points]
  (let [tangents (mapv
                  (fn [[p q]] (m/normalize (m/- q p)))
                  (partition 2 1 points))]
    ;; The last point has no successor. Without a tangent of its own
    ;; it gets no frame and is silently dropped once points, normals
    ;; and binormals are zipped together for sweeping.
    (if (seq tangents)
      (conj tangents (peek tangents))
      tangents)))

(defn compute-frame
  "Computes the frame for path index `i` from the frame at index
  `i - 1`. The previous normal is rotated around the axis given by the
  cross product of the previous and current tangent, by the angle
  between those two tangents. If that cross product has (approx.) zero
  length, the previous normal is reused unchanged. `norms` only needs
  to support `nth`. `bnorms` is accepted but not used.
  Returns a 2-element vector [normal binormal]."
  [tangents norms bnorms i]
  (let [prev-idx  (dec i)
        prev-tan  (nth tangents prev-idx)
        curr-tan  (nth tangents i)
        prev-norm (nth norms prev-idx)
        axis      (m/cross prev-tan curr-tan)
        normal    (if (m/delta= 0.0 (m/mag-squared axis))
                    ;; no usable rotation axis: carry the normal over
                    prev-norm
                    (let [angle (Math/acos (m/clamp-normalized (m/dot prev-tan curr-tan)))
                          rot   (g/rotate-around-axis M44 (m/normalize axis) angle)]
                      (g/transform-vector rot prev-norm)))]
    [normal (m/cross curr-tan normal)]))

(defn compute-first-frame
  "Takes a tangent vector (normalized dir between 2 path points) and
  returns a 2-element vector [normal binormal] for it. The helper axis
  is chosen as the component index at which the absolute tangent is
  smallest; V3 with that component set to 1.0 is crossed with the
  tangent to derive the normal (presumably V3 is the zero vector,
  making the helper a unit axis - verify against thi.ng.geom.vector)."
  [t]
  (let [mags     (m/abs t)
        ;; smaller of x/y first, then compare against z
        xy-axis  (if (< (v/x mags) (v/y mags)) 0 1)
        min-axis (if (< (v/z mags) (nth mags xy-axis)) 2 xy-axis)
        helper   (assoc V3 min-axis 1.0)
        normal   (->> helper
                      (m/cross t)
                      (m/normalize)
                      (m/cross t))]
    [normal (m/cross t normal)]))

(defn compute-frames
  "Takes a seq of 3d points and returns vector of its PTFs. The result
  is a vector of 4 elements: [points tangents normals binormals].
  The first frame comes from `compute-first-frame`, every further one
  is derived from its predecessor via `compute-frame`, yielding one
  normal and one binormal per tangent."
  [points]
  (let [tangents (compute-tangents points)
        frame0   (compute-first-frame (first tangents))
        step     (fn [[ns bs] idx]
                   (let [frame (compute-frame tangents ns bs idx)]
                     [(conj! ns (first frame)) (conj! bs (nth frame 1))]))
        acc      (reduce
                  step
                  [(transient [(first frame0)]) (transient [(nth frame0 1)])]
                  (range 1 (count tangents)))]
    [points tangents (persistent! (first acc)) (persistent! (nth acc 1))]))

(defn align-frames
  "Takes a vector of PTFs (as returned by compute-frames) and
  re-aligns all frames such that the orientation of the first and last
  are the same. Returns updated PTFs.

  The angle between the first and last normal is divided by the number
  of tangents minus one, giving a per-frame correction step. Its sign
  is flipped when the first tangent points the same way as the cross
  product of the first and last normal. Each frame at index i >= 1 has
  its normal rotated around its own tangent by (step * i) and its
  binormal recomputed; frame 0 is left untouched. `norms` must support
  `peek` (i.e. be a vector)."
  [[points tangents norms bnorms]]
  (let [n-frames   (count tangents)
        first-norm (first norms)
        last-norm  (peek norms)
        total      (Math/acos (m/clamp-normalized (m/dot first-norm last-norm)))
        step       (/ total (dec n-frames))
        flip?      (> (m/dot (first tangents) (m/cross first-norm last-norm)) 0.0)
        step       (if flip? (- step) step)
        realign    (fn [[ns bs] idx]
                     (let [tangent  (nth tangents idx)
                           rot      (g/rotate-around-axis M44 tangent (* step idx))
                           normal   (g/transform-vector rot (nth ns idx))
                           binormal (m/cross tangent normal)]
                       [(assoc! ns idx normal) (assoc! bs idx binormal)]))
        acc        (reduce
                    realign
                    [(transient norms) (transient bnorms)]
                    (range 1 n-frames))]
    [points tangents (persistent! (first acc)) (persistent! (nth acc 1))]))

(defn sweep-point
  "Takes a path point, a PTF normal & binormal and a profile point.
  Returns profile point projected on path (point) as vec3. Each
  component is built with `mm/madd` from the profile's first coordinate
  and the normal, the profile's second coordinate and the binormal,
  and the path point (presumably qx*n + qy*b + p - see
  thi.ng.math.macros/madd)."
  [p n b [qx qy]]
  (let [x (mm/madd qx (v/x n) qy (v/x b) (v/x p))
        y (mm/madd qx (v/y n) qy (v/y b) (v/y p))
        z (mm/madd qx (v/z n) qy (v/z b) (v/z p))]
    (vec3 x y z)))

(defn sweep-profile
  "Sweeps the 2D `profile` along the given PTFs
  ([points tangents normals binormals], tangents are ignored) and
  returns a vector of faces, each one the result of
  `attr/generate-face-attribs` for a quad spanning two neighboring
  profile rings.

  Frames are formed by zipping points, normals and binormals, so the
  shortest of the three determines how many rings are produced.

  Options:
  - `:close?` (default true) appends each ring's first vertex to its
    end, so the profile forms a closed loop
  - `:loop?` when truthy, appends the first frame after the last one,
    connecting the end of the sweep back to its start

  `attribs` is handed through to `attr/generate-face-attribs` together
  with a running face id and a map holding :u, :v, :du and :dv."
  [profile attribs opts [points _ norms bnorms]]
  (let [{:keys [close? loop?] :or {close? true}} opts
        frames      (map vector points norms bnorms)
        ;; projects the whole profile onto a single frame
        ring        (fn [[p n b]]
                      (let [verts (mapv #(sweep-point p n b %) profile)]
                        (if close? (conj verts (first verts)) verts)))
        ring0       (ring (first frames))
        nprof       (count profile)
        fid-step    (inc nprof)
        numf        (dec (count points))
        base-atts   {:du (/ 1.0 nprof) :dv (/ 1.0 numf)}
        more-frames (cond-> (next frames)
                      loop? (concat [(first frames)]))
        ;; reduction step: state is [faces previous-ring ring-index face-id]
        add-ring    (fn [[faces prev i fid] frame]
                      (let [curr     (ring frame)
                            row-atts (assoc base-atts :v (double (/ i numf)))
                            quad     (fn [j [a0 a1] [b0 b1]]
                                       (attr/generate-face-attribs
                                        [a0 a1 b1 b0]
                                        (+ fid j)
                                        attribs
                                        (assoc row-atts :u (double (/ j nprof)))))]
                        [(into faces
                               (map quad
                                    (range)
                                    (partition 2 1 prev)
                                    (partition 2 1 curr)))
                         curr
                         (inc i)
                         (+ fid fid-step)]))]
    (first (reduce add-ring [[] ring0 0 0] more-frames))))

(defn sweep-profile-faces
  "Takes a list of path points and seq of 2D profile vertices to sweep
  along path and optional map of sweep & vertex attrib options.
  Computes the path's PTFs, re-aligns them via `align-frames` if the
  `:align?` option is truthy, and passes them along with the
  `:attribs` option and the full options map to `sweep-profile`.
  Returns vector of raw faces (each face a vector of [verts attribs])."
  ([points profile]
   (sweep-profile-faces points profile nil))
  ([points profile {:keys [attribs align?] :as opts}]
   (->> (cond-> (compute-frames points)
          align? (align-frames))
        (sweep-profile profile attribs opts))))

(defn sweep-mesh
  "Like sweep-profile-faces, but returns result as mesh. The faces are
  added via `g/into` to the mesh given under the `:mesh` option, or to
  a new `bm/basic-mesh` if that option is missing."
  ([points profile]
   (sweep-mesh points profile nil))
  ([points profile opts]
   (let [target (or (:mesh opts) (bm/basic-mesh))
         faces  (sweep-profile-faces points profile opts)]
     (g/into target faces))))

(defn sweep-strand
  "Sweeps `profile` along a single strand wound around a base path and
  returns the result of `sweep-mesh`.

  Takes the base path's PTFs ([points tangents normals binormals],
  tangents are ignored; points, normals and binormals are invoked as
  functions of the index, so they must be vectors). For each index i
  in [0, (count points) - 1) a polar coordinate is built from radius
  `r` (or `(r i)` if `r` is a fn) and an angle derived from i, `delta`
  and `theta` via `mm/madd` (presumably i * delta + theta), converted
  with `g/as-cartesian` and projected into the base frame at i with
  `sweep-point`. `opts` are merged over the default {:align? true}."
  [[p _ n b] r theta delta profile opts]
  (let [radius-at  (if (fn? r) r (constantly r))
        strand-pt  (fn [i]
                     (sweep-point
                      (p i) (n i) (b i)
                      (g/as-cartesian
                       (vec2 (radius-at i) (mm/madd i delta theta)))))
        path       (map strand-pt (range (dec (count p))))
        sweep-opts (merge {:align? true} opts)]
    (sweep-mesh path profile sweep-opts)))

(defn sweep-strands
  "Builds one swept strand (see `sweep-strand`) per value of
  `(m/norm-range strands)` except the last, using that value times
  TWO_PI as the strand's start angle. The per-point angle increment is
  (twists * TWO_PI) divided by the number of base path points minus
  one. `base` is a vector of PTFs whose first element is the path
  points. Returns a seq of the strand results; on the JVM they are
  computed with `pmap`, in ClojureScript with `map`."
  [base r strands twists profile opts]
  (let [n-segments (dec (count (first base)))
        delta      (/ (* twists TWO_PI) n-segments)
        map-fn     #?(:clj pmap :cljs map)
        strand     (fn [t] (sweep-strand base r (* t TWO_PI) delta profile opts))]
    (map-fn strand (butlast (m/norm-range strands)))))

(defn sweep-strand-mesh
  "Like `sweep-strands`, but returns a single mesh. Without a `:mesh`
  option all strand meshes are merged via `g/into` into a new
  `bm/basic-mesh`. With a `:mesh` option the result of the last strand
  is returned."
  ([base r strands twists profile]
   (sweep-strand-mesh base r strands twists profile nil))
  ([base r strands twists profile opts]
   (let [strand-meshes (sweep-strands base r strands twists profile opts)]
     (if (get opts :mesh)
       ;; NOTE(review): every strand was swept into the user supplied
       ;; mesh, so this presumably relies on that mesh accumulating
       ;; faces in place. With an immutable mesh only the last strand
       ;; would be present here - confirm against the mesh types used.
       (last strand-meshes)
       (reduce g/into (bm/basic-mesh) strand-meshes)))))
